{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "q4WF3l23pumU"
      },
      "source": [
        "##### Copyright 2018 The AdaNet Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {},
        "colab_type": "code",
        "id": "Kic2quJWppmx"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "aL7SpaKdirqG"
      },
      "source": [
        "# Customizing AdaNet\n",
        "\n",
        "Often times, as a researcher or machine learning practitioner, you will have\n",
        "some prior knowledge about a dataset. Ideally you should be able to encode that\n",
        "knowledge into your machine learning algorithm. With `adanet`, you can do so by\n",
        "defining the *neural architecture search space* that the AdaNet algorithm should\n",
        "explore.\n",
        "\n",
        "In this tutorial, we will explore the flexibility of the `adanet` framework, and\n",
        "create a custom search space for an image-classificatio dataset using high-level\n",
        "TensorFlow libraries like `tf.layers`.\n",
        "\n"]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "id": "x_3b6xx2s6B9"
      },
      "outputs": [],
      "source": [
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "import functools\n",
        "\n",
        "import adanet\n",
        "from adanet.examples import simple_dnn\n",
        "import tensorflow as tf\n",
        "\n",
        "# The random seed to use.\n",
        "RANDOM_SEED = 42"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7gE5Mm9j2oYw"
      },
      "source": [
        "## Fashion MNIST dataset\n",
        "\n",
        "In this example, we will use the Fashion MNIST dataset\n",
        "[[Xiao et al., 2017](https://arxiv.org/abs/1708.07747)] for classifying fashion\n",
        "apparel images into one of ten categories:\n",
        "\n",
        "1.  T-shirt/top\n",
        "2.  Trouser\n",
        "3.  Pullover\n",
        "4.  Dress\n",
        "5.  Coat\n",
        "6.  Sandal\n",
        "7.  Shirt\n",
        "8.  Sneaker\n",
        "9.  Bag\n",
        "10. Ankle boot\n",
        "\n",
        "![Fashion MNIST](https://github.com/zalandoresearch/fashion-mnist/blob/master/doc/img/fashion-mnist-sprite.png?raw=true)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "5_hRtdchqRZb"
      },
      "source": [
        "## Download the data\n",
        "\n",
        "Conveniently, the data is available via Keras:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "id": "uYklOnPJ4h7g"
      },
      "outputs": [],
      "source": [
        "(x_train, y_train), (x_test, y_test) = (\n",
        "    tf.keras.datasets.fashion_mnist.load_data())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "tECo5dFd4QCa"
      },
      "source": [
        "## Supply the data in TensorFlow\n",
        "\n",
        "Our first task is to supply the data in TensorFlow. Using the\n",
        "tf.estimator.Estimator covention, we will define a function that returns an\n",
        "`input_fn` which returns feature and label `Tensors`.\n",
        "\n",
        "We will also use the `tf.data.Dataset` API to feed the data into our models."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "id": "gxTAoIXwsTH7"
      },
      "outputs": [],
      "source": [
        "FEATURES_KEY = \"images\"\n",
        "\n",
        "\n",
        "def generator(images, labels):\n",
        "  \"\"\"Returns a generator that returns image-label pairs.\"\"\"\n",
        "\n",
        "  def _gen():\n",
        "    for image, label in zip(images, labels):\n",
        "      yield image, label\n",
        "\n",
        "  return _gen\n",
        "\n",
        "\n",
        "def preprocess_image(image, label):\n",
        "  \"\"\"Preprocesses an image for an `Estimator`.\"\"\"\n",
        "  # First let's scale the pixel values to be between 0 and 1.\n",
        "  image = image / 255.\n",
        "  # Next we reshape the image so that we can apply a 2D convolution to it.\n",
        "  image = tf.reshape(image, [28, 28, 1])\n",
        "  # Finally the features need to be supplied as a dictionary.\n",
        "  features = {FEATURES_KEY: image}\n",
        "  return features, label\n",
        "\n",
        "\n",
        "def input_fn(partition, training, batch_size):\n",
        "  \"\"\"Generate an input_fn for the Estimator.\"\"\"\n",
        "\n",
        "  def _input_fn():\n",
        "    if partition == \"train\":\n",
        "      dataset = tf.data.Dataset.from_generator(\n",
        "          generator(x_train, y_train), (tf.float32, tf.int32), ((28, 28), ()))\n",
        "    else:\n",
        "      dataset = tf.data.Dataset.from_generator(\n",
        "          generator(x_test, y_test), (tf.float32, tf.int32), ((28, 28), ()))\n",
        "\n",
        "    # We call repeat after shuffling, rather than before, to prevent separate\n",
        "    # epochs from blending together.\n",
        "    if training:\n",
        "      dataset = dataset.shuffle(10 * batch_size, seed=RANDOM_SEED).repeat()\n",
        "\n",
        "    dataset = dataset.map(preprocess_image).batch(batch_size)\n",
        "    iterator = dataset.make_one_shot_iterator()\n",
        "    features, labels = iterator.get_next()\n",
        "    return features, labels\n",
        "\n",
        "  return _input_fn"
      ]
    },
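    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The comment in `_input_fn` about calling `repeat` after `shuffle` can be\n",
        "illustrated without TensorFlow. The sketch below is a simplified, pure-Python\n",
        "stand-in for a buffered shuffle. The helpers `buffered_shuffle`,\n",
        "`shuffle_then_repeat`, and `repeat_then_shuffle` are hypothetical names written\n",
        "for this illustration; they are not part of `tf.data`, whose implementation\n",
        "differs in detail.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "\n",
        "def buffered_shuffle(items, buffer_size, rng):\n",
        "  \"\"\"Yields items in a locally shuffled order using a fixed-size buffer.\"\"\"\n",
        "  buffer = []\n",
        "  for item in items:\n",
        "    buffer.append(item)\n",
        "    if len(buffer) >= buffer_size:\n",
        "      # Emit a random buffered item; the next input item takes its place.\n",
        "      yield buffer.pop(rng.randrange(len(buffer)))\n",
        "  # Drain whatever is left once the input is exhausted.\n",
        "  while buffer:\n",
        "    yield buffer.pop(rng.randrange(len(buffer)))\n",
        "\n",
        "\n",
        "def shuffle_then_repeat(examples, epochs, buffer_size, rng):\n",
        "  \"\"\"Shuffles each epoch on its own, then moves on to the next epoch.\"\"\"\n",
        "  for _ in range(epochs):\n",
        "    for example in buffered_shuffle(examples, buffer_size, rng):\n",
        "      yield example\n",
        "\n",
        "\n",
        "def repeat_then_shuffle(examples, epochs, buffer_size, rng):\n",
        "  \"\"\"Concatenates all epochs first, then shuffles the combined stream.\"\"\"\n",
        "  repeated = (example for _ in range(epochs) for example in examples)\n",
        "  return buffered_shuffle(repeated, buffer_size, rng)\n",
        "\n",
        "\n",
        "examples = list(range(6))\n",
        "rng = random.Random(42)\n",
        "separated = list(shuffle_then_repeat(examples, 2, 4, rng))\n",
        "blended = list(repeat_then_shuffle(examples, 2, 4, rng))\n",
        "print(\"shuffle then repeat:\", separated[:6], separated[6:])\n",
        "print(\"repeat then shuffle:\", blended)\n",
        "```\n",
        "\n",
        "With shuffle-then-repeat, each block of six outputs is a permutation of the six\n",
        "examples, so epochs stay separate. With repeat-then-shuffle, the buffer can hold\n",
        "examples from two epochs at once, so an example from the second epoch may be\n",
        "emitted before the first epoch has finished."
      ]
    },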
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "vm9yudEv5lQZ"
      },
      "source": [
        "## Establish baselines\n",
        "\n",
        "The next task should be to get somes baselines to see how our model performs on\n",
        "this dataset.\n",
        "\n",
        "Let's define some information to share with all our `tf.estimator.Estimators`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "id": "xNwSUWh-9_Ib"
      },
      "outputs": [],
      "source": [
        "# The number of classes.\n",
        "NUM_CLASSES = 10\n",
        "\n",
        "# We will average the losses in each mini-batch when computing gradients.\n",
        "loss_reduction = tf.losses.Reduction.SUM_OVER_BATCH_SIZE\n",
        "\n",
        "# A `Head` instance defines the loss function and metrics for `Estimators`.\n",
        "head = tf.contrib.estimator.multi_class_head(\n",
        "    NUM_CLASSES, loss_reduction=loss_reduction)\n",
        "\n",
        "# Some `Estimators` use feature columns for understanding their input features.\n",
        "feature_columns = [\n",
        "    tf.feature_column.numeric_column(FEATURES_KEY, shape=[28, 28, 1])\n",
        "]\n",
        "\n",
        "# Estimator configuration.\n",
        "config = tf.estimator.RunConfig(\n",
        "    save_checkpoints_steps=50000,\n",
        "    save_summary_steps=50000,\n",
        "    tf_random_seed=RANDOM_SEED)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "QY0cv-ot-Gxs"
      },
      "source": [
        "Let's start simple, and train a linear model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "both",
        "colab": {
          "height": 53,
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 32813,
          "status": "ok",
          "timestamp": 1534440488365,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "s8wJKsi06blX",
        "outputId": "5acd325b-1de9-4c21-d825-8f9c10ffdd6c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Accuracy: 0.8413\n",
            "Loss: 0.464809\n"
          ]
        }
      ],
      "source": [
        "#@test {\"skip\": true}\n",
        "#@title Parameters\n",
        "LEARNING_RATE = 0.001  #@param {type:\"number\"}\n",
        "TRAIN_STEPS = 5000  #@param {type:\"integer\"}\n",
        "BATCH_SIZE = 64  #@param {type:\"integer\"}\n",
        "\n",
        "estimator = tf.estimator.LinearClassifier(\n",
        "    feature_columns=feature_columns,\n",
        "    n_classes=NUM_CLASSES,\n",
        "    optimizer=tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE),\n",
        "    loss_reduction=loss_reduction,\n",
        "    config=config)\n",
        "\n",
        "results, _ = tf.estimator.train_and_evaluate(\n",
        "    estimator,\n",
        "    train_spec=tf.estimator.TrainSpec(\n",
        "        input_fn=input_fn(\"train\", training=True, batch_size=BATCH_SIZE),\n",
        "        max_steps=TRAIN_STEPS),\n",
        "    eval_spec=tf.estimator.EvalSpec(\n",
        "        input_fn=input_fn(\"test\", training=False, batch_size=BATCH_SIZE),\n",
        "        steps=None))\n",
        "print(\"Accuracy:\", results[\"accuracy\"])\n",
        "print(\"Loss:\", results[\"average_loss\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "a-1hE03c7_Yj"
      },
      "source": [
        "The linear model with default parameters achieves about **84.13% accuracy**.\n",
        "\n",
        "Let's see if we can do better with the `simple_dnn` AdaNet:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 53,
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 51048,
          "status": "ok",
          "timestamp": 1534440539502,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "9fAoRYd19eUs",
        "outputId": "0b5678f1-4d44-430b-d84d-d3315366d4b8"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Accuracy: 0.8566\n",
            "Loss: 0.408646\n"
          ]
        }
      ],
      "source": [
        "#@test {\"skip\": true}\n",
        "#@title Parameters\n",
        "LEARNING_RATE = 0.003  #@param {type:\"number\"}\n",
        "TRAIN_STEPS = 5000  #@param {type:\"integer\"}\n",
        "BATCH_SIZE = 64  #@param {type:\"integer\"}\n",
        "ADANET_ITERATIONS = 2  #@param {type:\"integer\"}\n",
        "\n",
        "estimator = adanet.Estimator(\n",
        "    head=head,\n",
        "    subnetwork_generator=simple_dnn.Generator(\n",
        "        feature_columns=feature_columns,\n",
        "        optimizer=tf.train.RMSPropOptimizer(learning_rate=LEARNING_RATE),\n",
        "        seed=RANDOM_SEED),\n",
        "    max_iteration_steps=TRAIN_STEPS // ADANET_ITERATIONS,\n",
        "    evaluator=adanet.Evaluator(\n",
        "        input_fn=input_fn(\"train\", training=False, batch_size=BATCH_SIZE),\n",
        "        steps=None),\n",
        "    config=config)\n",
        "\n",
        "results, _ = tf.estimator.train_and_evaluate(\n",
        "    estimator,\n",
        "    train_spec=tf.estimator.TrainSpec(\n",
        "        input_fn=input_fn(\"train\", training=True, batch_size=BATCH_SIZE),\n",
        "        max_steps=TRAIN_STEPS),\n",
        "    eval_spec=tf.estimator.EvalSpec(\n",
        "        input_fn=input_fn(\"test\", training=False, batch_size=BATCH_SIZE),\n",
        "        steps=None))\n",
        "print(\"Accuracy:\", results[\"accuracy\"])\n",
        "print(\"Loss:\", results[\"average_loss\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ysWsJ3zXDwNx"
      },
      "source": [
        "The `simple_dnn` AdaNet model with default parameters achieves about **85.66%\n",
        "accuracy**.\n",
        "\n",
        "This improvement can be attributed to `simple_dnn` searching over\n",
        "fully-connected neural networks which have more expressive power than the linear\n",
        "model due to their non-linear activations.\n",
        "\n",
        "Fully-connected layers are permutation invariant to their inputs, meaning that\n",
        "if we consistently swapped two pixels before training, the final model would\n",
        "perform identically. However, there is spatial and locality information in\n",
        "images that we should try to capture. Applying a few convolutions to our inputs\n",
        "will allow us to do so, and that will require defining a custom\n",
        "`adanet.subnetwork.Builder` and `adanet.subnetwork.Generator`."
      ]
    },
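    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The point about fully-connected layers ignoring pixel layout can be checked with\n",
        "a minimal sketch. It assumes NumPy is available (the notebook itself does not\n",
        "import it), and `dense_relu` is a hypothetical helper standing in for a single\n",
        "dense layer with random weights, not for the trained `simple_dnn` model.\n",
        "Shuffling the input pixels and the matching weight rows in the same way leaves\n",
        "the layer's output unchanged up to floating-point rounding:\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def dense_relu(inputs, weights, bias):\n",
        "  \"\"\"A single fully-connected layer followed by a ReLU activation.\"\"\"\n",
        "  return np.maximum(inputs.dot(weights) + bias, 0.)\n",
        "\n",
        "\n",
        "rng = np.random.RandomState(42)\n",
        "x = rng.rand(5, 784)    # A batch of five flattened 28x28 images.\n",
        "w = rng.randn(784, 64)  # One weight row per input pixel.\n",
        "b = rng.randn(64)\n",
        "\n",
        "# Apply the same fixed permutation to the pixels and to the weight rows.\n",
        "perm = rng.permutation(784)\n",
        "out = dense_relu(x, w, b)\n",
        "out_perm = dense_relu(x[:, perm], w[perm, :], b)\n",
        "\n",
        "print(np.allclose(out, out_perm))\n",
        "```\n",
        "\n",
        "Because any consistent pixel shuffle can be absorbed into the weights, the layer\n",
        "has no built-in notion of which pixels are neighbors. A convolution, by\n",
        "contrast, applies the same small kernel to adjacent pixels, so it does depend on\n",
        "the spatial layout."
      ]
    },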
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "D3IE6-9vFVlg"
      },
      "source": [
        "## Define a convolutional AdaNet model\n",
        "\n",
        "Creating a new search space for AdaNet to explore is straightforward. There are\n",
        "two abstract classes you need to extend:\n",
        "\n",
        "1.  `adanet.subnetwork.Builder`\n",
        "2.  `adanet.subnetwork.Generator`\n",
        "\n",
        "Similar to the tf.estimator.Estimator `model_fn`, `adanet.subnetwork.Builder`\n",
        "allows you to define your own TensorFlow graph for creating a neural network,\n",
        "and specify the training operations.\n",
        "\n",
        "Below we define one that applies a 2D convolution, max-pooling, and then a\n",
        "fully-connected layer to the images:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "id": "IsYJ97tRwBkt"
      },
      "outputs": [],
      "source": [
        "class SimpleCNNBuilder(adanet.subnetwork.Builder):\n",
        "  \"\"\"Builds a CNN subnetwork for AdaNet.\"\"\"\n",
        "\n",
        "  def __init__(self, learning_rate, max_iteration_steps, seed):\n",
        "    \"\"\"Initializes a `SimpleCNNBuilder`.\n",
        "\n",
        "    Args:\n",
        "      learning_rate: The float learning rate to use.\n",
        "      max_iteration_steps: The number of steps per iteration.\n",
        "      seed: The random seed.\n",
        "\n",
        "    Returns:\n",
        "      An instance of `SimpleCNNBuilder`.\n",
        "    \"\"\"\n",
        "    self._learning_rate = learning_rate\n",
        "    self._max_iteration_steps = max_iteration_steps\n",
        "    self._seed = seed\n",
        "\n",
        "  def build_subnetwork(self,\n",
        "                       features,\n",
        "                       logits_dimension,\n",
        "                       training,\n",
        "                       iteration_step,\n",
        "                       summary,\n",
        "                       previous_ensemble=None):\n",
        "    \"\"\"See `adanet.subnetwork.Builder`.\"\"\"\n",
        "    images = features.values()[0]\n",
        "    kernel_initializer = tf.keras.initializers.he_normal(seed=self._seed)\n",
        "    x = tf.layers.conv2d(\n",
        "        images,\n",
        "        filters=16,\n",
        "        kernel_size=3,\n",
        "        padding=\"same\",\n",
        "        activation=\"relu\",\n",
        "        kernel_initializer=kernel_initializer)\n",
        "    x = tf.layers.max_pooling2d(x, pool_size=2, strides=2)\n",
        "    x = tf.layers.flatten(x)\n",
        "    x = tf.layers.dense(\n",
        "        x, units=64, activation=\"relu\", kernel_initializer=kernel_initializer)\n",
        "\n",
        "    # The `Head` passed to adanet.Estimator will apply the softmax activation.\n",
        "    logits = tf.layers.dense(\n",
        "        x, units=10, activation=None, kernel_initializer=kernel_initializer)\n",
        "\n",
        "    # Use a constant complexity measure, since all subnetworks have the same\n",
        "    # architecture and hyperparameters.\n",
        "    complexity = tf.constant(1)\n",
        "\n",
        "    return adanet.Subnetwork(\n",
        "        last_layer=x,\n",
        "        logits=logits,\n",
        "        complexity=complexity,\n",
        "        persisted_tensors={})\n",
        "\n",
        "  def build_subnetwork_train_op(self, \n",
        "                                subnetwork, \n",
        "                                loss, \n",
        "                                var_list, \n",
        "                                labels, \n",
        "                                iteration_step,\n",
        "                                summary, \n",
        "                                previous_ensemble=None):\n",
        "    \"\"\"See `adanet.subnetwork.Builder`.\"\"\"\n",
        "\n",
        "    # Momentum optimizer with cosine learning rate decay works well with CNNs.\n",
        "    learning_rate = tf.train.cosine_decay(\n",
        "        learning_rate=self._learning_rate,\n",
        "        global_step=iteration_step,\n",
        "        decay_steps=self._max_iteration_steps)\n",
        "    optimizer = tf.train.MomentumOptimizer(learning_rate, .9)\n",
        "    # NOTE: The `adanet.Estimator` increments the global step.\n",
        "    return optimizer.minimize(loss=loss, var_list=var_list)\n",
        "\n",
        "  def build_mixture_weights_train_op(self, loss, var_list, logits, labels,\n",
        "                                     iteration_step, summary):\n",
        "    \"\"\"See `adanet.subnetwork.Builder`.\"\"\"\n",
        "    return tf.no_op(\"mixture_weights_train_op\")\n",
        "\n",
        "  @property\n",
        "  def name(self):\n",
        "    \"\"\"See `adanet.subnetwork.Builder`.\"\"\"\n",
        "    return \"simple_cnn\""
      ]
    },
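    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the learning rate schedule in `build_subnetwork_train_op` concrete, here\n",
        "is a plain-Python sketch of the cosine decay formula with its default settings\n",
        "(no minimum learning rate). The function below is a hypothetical\n",
        "re-implementation for illustration only; the notebook itself relies on\n",
        "`tf.train.cosine_decay`. The values 0.05 and 2500 are the `LEARNING_RATE` and\n",
        "`max_iteration_steps` used later in this notebook.\n",
        "\n",
        "```python\n",
        "import math\n",
        "\n",
        "\n",
        "def cosine_decay(learning_rate, step, decay_steps):\n",
        "  \"\"\"Decays `learning_rate` to zero along half a cosine period.\"\"\"\n",
        "  # Steps beyond `decay_steps` keep the final (zero) learning rate.\n",
        "  step = min(step, decay_steps)\n",
        "  fraction = float(step) / decay_steps\n",
        "  return learning_rate * 0.5 * (1. + math.cos(math.pi * fraction))\n",
        "\n",
        "\n",
        "for step in (0, 625, 1250, 1875, 2500):\n",
        "  print(step, cosine_decay(0.05, step, 2500))\n",
        "```\n",
        "\n",
        "The rate starts at 0.05, passes through half that value at the midpoint, and\n",
        "reaches zero at step 2500. Since the schedule is driven by `iteration_step`, it\n",
        "restarts from the full learning rate for the subnetwork trained in the next\n",
        "AdaNet iteration."
      ]
    },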
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "OFamPrZHJ5ii"
      },
      "source": [
        "Next, we extend a `adanet.subnetwork.Generator`, which defines the search\n",
        "space of candidate `SimpleCNNBuilders` to consider including the final network.\n",
        "It can create one or more at each iteration with different parameters, and the\n",
        "AdaNet algorithm will select the candidate that best improves the overall neural\n",
        "network's `adanet_loss` on the training set.\n",
        "\n",
        "The one below is very simple: it always creates the same architecture, but gives\n",
        "it a different random seed at each iteration:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "id": "-BAnb_XGwhRy"
      },
      "outputs": [],
      "source": [
        "class SimpleCNNGenerator(adanet.subnetwork.Generator):\n",
        "  \"\"\"Generates a `SimpleCNN` at each iteration.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, learning_rate, max_iteration_steps, seed=None):\n",
        "    \"\"\"Initializes a `Generator` that builds `SimpleCNNs`.\n",
        "\n",
        "    Args:\n",
        "      learning_rate: The float learning rate to use.\n",
        "      max_iteration_steps: The number of steps per iteration.\n",
        "      seed: The random seed.\n",
        "\n",
        "    Returns:\n",
        "      An instance of `Generator`.\n",
        "    \"\"\"\n",
        "    self._seed = seed\n",
        "    self._dnn_builder_fn = functools.partial(\n",
        "        SimpleCNNBuilder,\n",
        "        learning_rate=learning_rate,\n",
        "        max_iteration_steps=max_iteration_steps)\n",
        "\n",
        "  def generate_candidates(self, previous_ensemble, iteration_number,\n",
        "                          previous_ensemble_reports, all_reports):\n",
        "    \"\"\"See `adanet.subnetwork.Generator`.\"\"\"\n",
        "    seed = self._seed\n",
        "    # Change the seed according to the iteration so that each subnetwork\n",
        "    # learns something different.\n",
        "    if seed is not None:\n",
        "      seed += iteration_number\n",
        "    return [self._dnn_builder_fn(seed=seed)]"
      ]
    },
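    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The seed handling in `generate_candidates` can be tried in isolation. In the\n",
        "sketch below, `make_builder` is a hypothetical stand-in for the\n",
        "`SimpleCNNBuilder` constructor that simply records its arguments, and the\n",
        "module-level `generate_candidates` function mirrors the method above without\n",
        "depending on `adanet`. The values 0.05, 2500, and 42 are the learning rate,\n",
        "steps per iteration, and random seed used in this notebook.\n",
        "\n",
        "```python\n",
        "import functools\n",
        "\n",
        "\n",
        "def make_builder(learning_rate, max_iteration_steps, seed):\n",
        "  \"\"\"Hypothetical stand-in that records the builder's arguments.\"\"\"\n",
        "  return {\n",
        "      \"learning_rate\": learning_rate,\n",
        "      \"max_iteration_steps\": max_iteration_steps,\n",
        "      \"seed\": seed,\n",
        "  }\n",
        "\n",
        "\n",
        "def generate_candidates(builder_fn, base_seed, iteration_number):\n",
        "  \"\"\"Returns one candidate whose seed is offset by the iteration number.\"\"\"\n",
        "  seed = base_seed\n",
        "  if seed is not None:\n",
        "    seed += iteration_number\n",
        "  return [builder_fn(seed=seed)]\n",
        "\n",
        "\n",
        "# `functools.partial` fixes the shared arguments; only the seed varies.\n",
        "builder_fn = functools.partial(\n",
        "    make_builder, learning_rate=0.05, max_iteration_steps=2500)\n",
        "for iteration_number in range(2):\n",
        "  print(generate_candidates(builder_fn, 42, iteration_number))\n",
        "```\n",
        "\n",
        "Each iteration yields a single candidate with the same hyperparameters and a\n",
        "seed of 42, then 43. When no seed is given, the seed stays `None` in every\n",
        "iteration."
      ]
    },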
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "8sdvharsLJ1T"
      },
      "source": [
        "With these defined, we pass them into a new `adanet.Estimator`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "height": 53,
          "test": {
            "output": "ignore",
            "timeout": 900
          }
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 74112,
          "status": "ok",
          "timestamp": 1534440614010,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "-Fhi1SjkzVBt",
        "outputId": "fd4b8c9c-4665-473b-ea7d-cd47c8680811"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Accuracy: 0.9041\n",
            "Loss: 0.26544\n"
          ]
        }
      ],
      "source": [
        "#@title Parameters\n",
        "LEARNING_RATE = 0.05  #@param {type:\"number\"}\n",
        "TRAIN_STEPS = 5000  #@param {type:\"integer\"}\n",
        "BATCH_SIZE = 64  #@param {type:\"integer\"}\n",
        "ADANET_ITERATIONS = 2  #@param {type:\"integer\"}\n",
        "\n",
        "max_iteration_steps = TRAIN_STEPS // ADANET_ITERATIONS\n",
        "estimator = adanet.Estimator(\n",
        "    head=head,\n",
        "    subnetwork_generator=SimpleCNNGenerator(\n",
        "        learning_rate=LEARNING_RATE,\n",
        "        max_iteration_steps=max_iteration_steps,\n",
        "        seed=RANDOM_SEED),\n",
        "    max_iteration_steps=max_iteration_steps,\n",
        "    evaluator=adanet.Evaluator(\n",
        "        input_fn=input_fn(\"train\", training=False, batch_size=BATCH_SIZE),\n",
        "        steps=None),\n",
        "    report_materializer=adanet.ReportMaterializer(\n",
        "        input_fn=input_fn(\"train\", training=False, batch_size=BATCH_SIZE),\n",
        "        steps=None),\n",
        "    adanet_loss_decay=.99,\n",
        "    config=config)\n",
        "\n",
        "results, _ = tf.estimator.train_and_evaluate(\n",
        "    estimator,\n",
        "    train_spec=tf.estimator.TrainSpec(\n",
        "        input_fn=input_fn(\"train\", training=True, batch_size=BATCH_SIZE),\n",
        "        max_steps=TRAIN_STEPS),\n",
        "    eval_spec=tf.estimator.EvalSpec(\n",
        "        input_fn=input_fn(\"test\", training=False, batch_size=BATCH_SIZE),\n",
        "        steps=None))\n",
        "print(\"Accuracy:\", results[\"accuracy\"])\n",
        "print(\"Loss:\", results[\"average_loss\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "3wGtI-4_LRw1"
      },
      "source": [
        "Our `SimpleCNNGenerator` code achieves **90.41% accuracy**."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "TKhCzP65hGyS"
      },
      "source": [
        "## Conclusion and next steps\n",
        "\n",
        "In this tutorial, you learned how to customize `adanet` to encode your\n",
        "understanding of a particular dataset, and explore novel search spaces with\n",
        "AdaNet.\n",
        "\n",
        "One use-case that has worked for us at Google, has been to take a production\n",
        "model's TensorFlow code, convert it to into an `adanet.subnetwork.Builder`, and\n",
        "adaptively grow it into an ensemble. In many cases, this has given significant\n",
        "performance improvements.\n",
        "\n",
        "As an exercise, you can swap out the FASHION-MNIST with the MNIST handwritten\n",
        "digits dataset in this notebook using `tf.keras.datasets.mnist.load_data()`, and\n",
        "see how `SimpleCNN` performs."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "customizing_adanet.ipynb",
      "provenance": [],
      "version": "0.3.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
